package api

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	log "github.com/sirupsen/logrus"

	"github.com/slimtoolkit/slim/pkg/vulnerability/epss"
)

const (
	// callPath is the EPSS API endpoint that every request is sent to.
	callPath = "https://api.first.org/data/v1/epss"

	// Query-string spellings of boolean parameter values.
	trueStr  = "true"
	falseStr = "false"

	// GLOBAL PARAMETERS:

	// Comma-separated list of field names to be retrieved.
	// Used only for limiting the available result set.
	qsFields = "fields" // type: string

	// Limits the maximum number of records to be shown.
	// Should be a number between 1 and 100.
	qsLimit = "limit" // type: integer

	// Offsets the list of records by this number.
	// The first item is 0.
	qsOffset = "offset" // type: integer

	// Comma-separated list of field names to be used to sort the result set.
	// Fields starting with - (minus sign) indicate a descending order.
	// Each application should define its default sorting options.
	// NOTE: not referenced in this file; record ordering is sent via qsOrder.
	qsSort = "sort" // type: string

	// Use true, false, 0 or 1. If set to true, an object wrapping the
	// result set is added, with details on the status, total records found,
	// offset and limit. When false, this information is returned in the
	// response headers. Defaults to true.
	qsEnvelope = "envelope" // type: bool

	// Use true, false, 0 or 1. Whether the result should be
	// pretty-printed or not. Defaults to false.
	qsPretty = "pretty" // type: bool

	// Collection of field names to retrieve. Affects the result set and
	// the possible options for the fields parameter. Each data model
	// can specify multiple available scopes.
	qsScope = "scope" // type: string
	// Shows the last 30 days (or what is available)
	// of the EPSS score and percentile for any CVE.
	tsScopeVal = "time-series"
	// Shows the basic EPSS data (cve, epss, percentile, created).
	pubScopeVal = "public" // default scope value

	// EPSS CALL PARAMETERS:

	// Filters by CVE ID. Multiple values are supported, separated by commas.
	// The maximum size accepted for this parameter is 2000 characters (including commas).
	qsCVE           = "cve" // type: string
	maxCVEValueSize = 2000

	// Date in the format YYYY-MM-DD (since April 14, 2021);
	// shows the historic values for the epss and percentile attributes.
	qsDate = "date" // type: date

	// Number of days since the EPSS score was added to the database
	// (starting at 1, not affected by the date parameter).
	qsDays = "days" // type: int

	// Only display CVEs with an EPSS score greater than or equal to the parameter.
	qsEPSSGt = "epss-gt" // type: decimal

	// Only display CVEs with a percentile greater than or equal to the parameter.
	qsPctGt = "percentile-gt" // type: decimal

	// Only display CVEs with an EPSS score lower than or equal to the parameter.
	qsEPSSLt = "epss-lt" // type: decimal

	// Only display CVEs with a percentile lower than or equal to the parameter.
	qsPctLt = "percentile-lt" // type: decimal

	// Free-text search on the CVE ID (allows partial matches).
	qsQuery = "q" // type: string

	// Record ordering. Not fully documented by the API. The values below are
	// sent as-is; the descending variants (per the constant names) carry a
	// "!" prefix rather than the "-" described for qsSort.
	qsOrder           = "order"
	orderScoreDescVal = "!epss"
	orderScoreAscVal  = "epss"
	orderPctDescVal   = "!percentile"
	orderPctAscVal    = "percentile"
)

//reminder:
//var _ connector.APIInt   = (*API)(nil)

// Instance is a client for the FIRST EPSS API (callPath); create it with New.
//
// Fields:
//   - client: HTTP client used for every request; New sets its Timeout.
//   - pretty: when true, every request adds pretty=true.
//   - pageSize: used by the Generic*Call methods as the default value of the
//     'limit' query parameter when a call does not set its own PageSize.
//   - debug: copied from Options.Debug; not read in this file.
//   - logger: component logger, tagged with com=epss.api by New.
type Instance struct {
	client   *http.Client
	pretty   bool
	pageSize uint64
	debug    bool
	logger   *log.Entry
}

// Options configures an Instance created by New. Only the first Options value
// passed to New is used; zero values fall back to defaults:
//   - APITimeout: HTTP client timeout in seconds (New multiplies it by
//     time.Second); values <= 0 use epss.APITimeout.
//   - Pretty: when true, requests add pretty=true.
//   - PageSize: default for the 'limit' query parameter; 0 uses epss.PageSize.
//   - Debug: stored on the Instance; not read in this file.
//   - Logger: base logger; nil uses an entry on the logrus standard logger.
type Options struct {
	APITimeout int
	Pretty     bool
	PageSize   uint64 // will be used as the default 'limit' value
	Debug      bool
	Logger     *log.Entry
}

// CallOptions holds per-request settings for LookupCall and GenericLookupCall.
// It is also embedded in FilteredCallOptions for the list calls. Only the
// first value passed to those calls is used.
//   - Date: YYYY-MM-DD; requests historic values for that day. It is parsed
//     and range-checked (epss.DateFromString / epss.IsValidDate) before it is
//     sent; a date that parses to the zero time is not sent.
//   - PageSize: value of the 'limit' query parameter; 0 falls back to the
//     Instance page size.
//   - Offset: value of the 'offset' query parameter; 0 is not sent.
//   - OutputFields: joined with commas into the 'fields' query parameter
//     (field names are not validated).
//   - WithHistory: requests the time-series scope and, when Output is nil,
//     makes LookupCall/ListCall allocate a history-capable reply.
//   - Output: optional target the JSON response is decoded into;
//     LookupCall/ListCall expect it to implement epss.ReplyType.
type CallOptions struct {
	Date         string
	PageSize     uint64 // will be used as the 'limit' value (0 = instance default)
	Offset       uint64
	OutputFields []string
	WithHistory  bool
	Output       interface{}
}

// FilteredCallOptions extends CallOptions with the filters accepted by
// ListCall and GenericListCall. Zero (or non-positive) filter values are not
// sent.
//   - CveIDPattern: free-text CVE ID search ('q' parameter, partial matches).
//   - DaysSinceAdded: 'days' parameter.
//   - ScoreGt / ScoreLt: 'epss-gt' / 'epss-lt' score bounds.
//   - PercentileGt / PercentileLt: 'percentile-gt' / 'percentile-lt' bounds.
//   - OrderRecords: mapped to the 'order' parameter; epss.NoOrder (and any
//     value without a mapping) sends no ordering.
type FilteredCallOptions struct {
	CallOptions
	CveIDPattern   string
	DaysSinceAdded uint
	ScoreGt        float64
	ScoreLt        float64
	PercentileGt   float64
	PercentileLt   float64
	OrderRecords   epss.OrderType
}

// New creates an EPSS API client.
//
// Only the first Options value is used; any extra values are ignored. Zero
// values fall back to defaults: APITimeout (seconds) to epss.APITimeout,
// PageSize to epss.PageSize and Logger to an entry on the logrus standard
// logger. The default page size is applied even when no Options are passed,
// so every call gets the same default 'limit'.
func New(options ...Options) *Instance {
	var opts Options
	if len(options) > 0 {
		opts = options[0]
	}

	ref := &Instance{
		debug:    opts.Debug,
		pretty:   opts.Pretty,
		pageSize: epss.PageSize,
	}

	if opts.PageSize > 0 {
		ref.pageSize = opts.PageSize
	}

	timeoutSec := time.Duration(epss.APITimeout)
	if opts.APITimeout > 0 {
		timeoutSec = time.Duration(opts.APITimeout)
	}

	logger := opts.Logger
	if logger == nil {
		logger = log.NewEntry(log.StandardLogger())
	}

	ref.logger = logger.WithField("com", "epss.api")
	ref.client = &http.Client{
		Timeout: timeoutSec * time.Second,
	}

	return ref
}

// ListCall queries the EPSS API for records matching the filters in the first
// FilteredCallOptions value (extra values are ignored) and returns the
// JSON-decoded reply.
//
// If Output is nil, a reply object is allocated; it supports history when
// WithHistory is set. A caller-supplied Output must implement
// epss.ReplyType; otherwise epss.ErrInvalidParams is returned instead of
// panicking. The caller's options slice is never modified.
func (ref *Instance) ListCall(
	ctx context.Context,
	options ...FilteredCallOptions) (epss.ReplyType, error) {
	// Work on a copy so filling in Output never writes into the caller's slice.
	var opt FilteredCallOptions
	if len(options) > 0 {
		opt = options[0]
	}

	var output epss.ReplyType
	if opt.Output == nil {
		output = allocOutput(opt.WithHistory)
		opt.Output = output
	} else {
		var ok bool
		output, ok = opt.Output.(epss.ReplyType)
		if !ok {
			ref.logger.WithFields(log.Fields{
				"op":          "epss.api.Instance.ListCall",
				"error":       epss.ErrInvalidParams,
				"output.type": fmt.Sprintf("%T", opt.Output),
			}).Error("unexpected Output type")
			return nil, epss.ErrInvalidParams
		}
	}

	if _, err := ref.GenericListCall(ctx, epss.OutJSON, opt); err != nil {
		return nil, err
	}

	return output, nil
}

// GenericListCall queries the EPSS API using the filters in the first
// FilteredCallOptions value (extra values are ignored). It returns the raw
// response body in the requested format; an empty format means JSON.
//
// Zero or non-positive filters are not sent. The 'limit' parameter defaults
// to the instance page size, including when no options are passed at all.
// When Output is set and the format is JSON, the response is also decoded
// into Output.
func (ref *Instance) GenericListCall(
	ctx context.Context,
	format string,
	options ...FilteredCallOptions) (string, error) {
	output := callOutput{
		format: format,
	}

	// The instance page size is the default 'limit' whether or not options
	// are given, matching the documented meaning of Options.PageSize.
	input := &callInput{
		limit: ref.pageSize,
	}

	if len(options) > 0 {
		opt := options[0]

		// Non-positive bounds mean "no filter" and are left unset.
		if opt.ScoreGt > 0 {
			input.scoreGt = opt.ScoreGt
		}
		if opt.ScoreLt > 0 {
			input.scoreLt = opt.ScoreLt
		}
		if opt.PercentileGt > 0 {
			input.percentileGt = opt.PercentileGt
		}
		if opt.PercentileLt > 0 {
			input.percentileLt = opt.PercentileLt
		}

		// Zero values here are also the "not set" values of callInput.
		input.days = opt.DaysSinceAdded
		input.query = opt.CveIDPattern

		// epss.NoOrder and any unmapped value leave the order unset.
		switch opt.OrderRecords {
		case epss.ScoreDescOrder:
			input.order = orderScoreDescVal
		case epss.ScoreAscOrder:
			input.order = orderScoreAscVal
		case epss.PercentileDescOrder:
			input.order = orderPctDescVal
		case epss.PercentileAscOrder:
			input.order = orderPctAscVal
		}

		input.offset = opt.Offset
		if opt.PageSize > 0 {
			input.limit = opt.PageSize
		}

		input.date = opt.Date
		input.fields = opt.OutputFields
		input.history = opt.WithHistory
		output.decoded = opt.Output
	}

	if _, err := ref.call(ctx, input, &output); err != nil {
		return "", err
	}

	return output.raw, nil
}

// allocOutput builds an empty reply value for a JSON response to be decoded
// into. It returns an *epss.APIResultWithHistory when withHistory is true and
// an *epss.APIResult otherwise. Either one comes with a freshly allocated
// ReplyMetadata.
func allocOutput(withHistory bool) epss.ReplyType {
	base := epss.Reply{
		ReplyMetadata: &epss.ReplyMetadata{},
	}

	if !withHistory {
		return &epss.APIResult{Reply: base}
	}

	return &epss.APIResultWithHistory{Reply: base}
}

// LookupCall fetches EPSS data for the given CVE IDs and returns the
// JSON-decoded reply. Only the first CallOptions value is used.
//
// If Output is nil, a reply object is allocated; it supports history when
// WithHistory is set. A caller-supplied Output must implement
// epss.ReplyType; otherwise epss.ErrInvalidParams is returned instead of
// panicking. The caller's options slice is never modified.
func (ref *Instance) LookupCall(
	ctx context.Context,
	cveIDs []string,
	options ...CallOptions) (epss.ReplyType, error) {
	// Work on a copy so filling in Output never writes into the caller's slice.
	var opt CallOptions
	if len(options) > 0 {
		opt = options[0]
	}

	var output epss.ReplyType
	if opt.Output == nil {
		output = allocOutput(opt.WithHistory)
		opt.Output = output
	} else {
		var ok bool
		output, ok = opt.Output.(epss.ReplyType)
		if !ok {
			ref.logger.WithFields(log.Fields{
				"op":          "epss.api.Instance.LookupCall",
				"error":       epss.ErrInvalidParams,
				"output.type": fmt.Sprintf("%T", opt.Output),
			}).Error("unexpected Output type")
			return nil, epss.ErrInvalidParams
		}
	}

	if _, err := ref.GenericLookupCall(ctx, cveIDs, epss.OutJSON, opt); err != nil {
		return nil, err
	}

	return output, nil
}

// GenericLookupCall fetches EPSS data for the given CVE IDs and returns the
// raw response body in the requested format; an empty format means JSON.
// Only the first CallOptions value is used.
//
// The 'limit' parameter defaults to the instance page size, including when no
// options are passed. When Output is set and the format is JSON, the response
// is also decoded into Output.
func (ref *Instance) GenericLookupCall(
	ctx context.Context,
	cveIDs []string,
	format string,
	options ...CallOptions) (string, error) {
	output := callOutput{
		format: format,
	}

	// The instance page size is the default 'limit' whether or not options
	// are given, matching the documented meaning of Options.PageSize.
	input := &callInput{
		cveList: cveIDs,
		limit:   ref.pageSize,
	}

	if len(options) > 0 {
		opt := options[0]

		input.offset = opt.Offset
		if opt.PageSize > 0 {
			input.limit = opt.PageSize
		}

		input.date = opt.Date
		input.fields = opt.OutputFields
		input.history = opt.WithHistory
		output.decoded = opt.Output
	}

	if _, err := ref.call(ctx, input, &output); err != nil {
		return "", err
	}

	return output.raw, nil
}

// callInput describes one API request at the transport level. The
// Generic*Call methods build it, and call turns it into query-string
// parameters. A zero (or non-positive) value means the parameter is not sent.
//   - cve: a single CVE ID, added to cveList when set.
//   - cveList: CVE IDs joined with commas into the 'cve' parameter; they are
//     validated, and the joined value may be at most maxCVEValueSize
//     characters.
//   - offset, limit: paging parameters.
//   - date: YYYY-MM-DD date for historic values.
//   - fields: field names for the 'fields' parameter.
//   - history: selects the time-series scope.
//   - order, query, scoreGt, scoreLt, percentileGt, percentileLt, days:
//     list filters for 'order', 'q', 'epss-gt', 'epss-lt', 'percentile-gt',
//     'percentile-lt' and 'days'.
type callInput struct {
	cve     string
	cveList []string
	offset  uint64
	limit   uint64
	date    string
	fields  []string
	history bool
	// filter fields (list calls)
	order        string
	query        string
	scoreGt      float64
	scoreLt      float64
	percentileGt float64
	percentileLt float64
	days         uint
}

// callOutput carries the response settings and results for call:
//   - format: requested response format, sent as the Accept header; call
//     replaces an empty value with epss.OutJSON.
//   - decoded: optional target; a JSON response body is decoded into it.
//   - decodedOnly: when true and decoded is set, raw is left empty.
//   - raw: the response body as a string.
type callOutput struct {
	format      string
	decoded     interface{}
	decodedOnly bool
	raw         string
}

// call performs a single GET request against the EPSS API (callPath).
//
// input may be nil; otherwise encodeCallQuery validates it and turns it into
// query-string parameters. The envelope parameter is always sent.
//
// output may be nil. When it is non-nil:
//   - output.format selects the Accept header; an empty format becomes
//     epss.OutJSON and is written back.
//   - the response body is read fully and closed before returning.
//   - a JSON body is decoded into output.decoded when that is set.
//   - the raw body text is stored in output.raw, unless output.decodedOnly is
//     set together with a decode target.
//
// When output is nil the body is NOT closed here; the caller owns the
// returned response.
//
// Errors:
//   - epss.ErrInvalidParams for an unsupported output format.
//   - validation errors from encodeCallQuery.
//   - transport errors from the HTTP client.
//   - epss.ErrNotFound for 404, epss.ErrNotAuthorized for 403, and a generic
//     error for any other non-200 status.
//   - body read errors (previously ignored, which hid truncated responses)
//     and JSON decode errors.
func (ref *Instance) call(
	ctx context.Context,
	input *callInput,
	output *callOutput) (*http.Response, error) {
	logger := ref.logger.WithField("op", "epss.api.Instance.call")

	outFormat := epss.OutJSON
	if output != nil {
		if output.format == "" {
			output.format = outFormat
		}

		if !epss.IsValidOutput(output.format) {
			logger.WithFields(log.Fields{
				"error":         epss.ErrInvalidParams,
				"output.format": output.format,
			}).Error("epss.IsValidOutput")
			return nil, epss.ErrInvalidParams
		}

		outFormat = output.format
	}

	var body io.Reader // no call param to put in the body (yet)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, callPath, body)
	if err != nil {
		logger.WithError(err).Error("http.NewRequestWithContext")
		return nil, err
	}

	req.Header.Set("Accept", outFormat)

	qs, err := encodeCallQuery(input, ref.pretty, logger)
	if err != nil {
		return nil, err
	}

	req.URL.RawQuery = qs.Encode()

	logger.WithFields(log.Fields{
		"path": callPath,
		"qs":   req.URL.RawQuery,
	}).Trace("ref.client.Do")

	resp, err := ref.client.Do(req)
	if resp == nil || resp.Body == nil {
		return resp, err
	}

	if output != nil {
		defer resp.Body.Close()
	}

	if err != nil {
		logger.WithError(err).Error("ref.client.Do")
		return resp, err
	}

	if resp.StatusCode != http.StatusOK {
		logger.WithField("status.code", resp.StatusCode).Error("ref.client.Do")

		switch resp.StatusCode {
		case http.StatusNotFound:
			return resp, epss.ErrNotFound
		case http.StatusForbidden:
			return resp, epss.ErrNotAuthorized
		}

		return resp, fmt.Errorf("bad http status - %d", resp.StatusCode)
	}

	if output == nil {
		return resp, nil
	}

	// A failed read (e.g. client timeout mid-body) must not be treated as a
	// complete, successful response.
	var b bytes.Buffer
	if _, err := b.ReadFrom(resp.Body); err != nil {
		logger.WithError(err).Error("resp.Body.Read")
		return resp, fmt.Errorf("reading epss api response body: %w", err)
	}

	// Non-JSON responses are only returned as raw strings.
	if output.decoded != nil && outFormat == epss.OutJSON {
		decoder := json.NewDecoder(bytes.NewReader(b.Bytes()))
		if err := decoder.Decode(output.decoded); err != nil {
			logger.WithFields(log.Fields{
				"error":          err,
				"output.decoded": output.decoded,
			}).Error("decoder.Decode")
			return resp, err
		}
	}

	if output.decoded == nil || !output.decodedOnly {
		output.raw = b.String()
	}

	return resp, nil
}

// encodeCallQuery validates input and converts it into EPSS API query-string
// parameters.
//
// The envelope parameter is always set, and pretty=true is added when pretty
// is true. A nil input yields only those global parameters. Zero or
// non-positive values in input are not sent. input is not modified. A
// validation failure is logged on logger and returned.
func encodeCallQuery(input *callInput, pretty bool, logger *log.Entry) (url.Values, error) {
	qs := url.Values{}
	// Always set the envelope because the response headers don't include all fields.
	qs.Set(qsEnvelope, trueStr)

	if pretty {
		qs.Set(qsPretty, trueStr)
	}

	if input == nil {
		return qs, nil
	}

	if input.query != "" {
		qs.Set(qsQuery, input.query)
	}

	cveList := input.cveList
	if input.cve != "" {
		// The full slice expression caps capacity, so append always copies
		// instead of writing into the caller-supplied backing array.
		cveList = append(cveList[:len(cveList):len(cveList)], input.cve)
	}

	if len(cveList) > 0 {
		if err := epss.IsValidCveList(cveList); err != nil {
			logger.WithFields(log.Fields{
				"error":         err,
				"input.cveList": cveList,
			}).Error("epss.IsValidCveList")
			return nil, err
		}

		cveValue := strings.Join(cveList, ",")
		if len(cveValue) > maxCVEValueSize {
			logger.WithFields(log.Fields{
				"error":         epss.ErrTooManyCVEs,
				"input.cveList": cveList,
			}).Error("maxCVEValueSize")
			return nil, epss.ErrTooManyCVEs
		}

		qs.Set(qsCVE, cveValue)
	}

	if input.date != "" {
		date, err := epss.DateFromString(input.date)
		if err != nil {
			logger.WithFields(log.Fields{
				"error":      err,
				"input.date": input.date,
			}).Error("epss.DateFromString")
			return nil, err
		}

		if !date.IsZero() {
			if !epss.IsValidDate(date) {
				logger.WithFields(log.Fields{
					"error": epss.ErrInvalidDateParam,
					"date":  date,
				}).Error("epss.IsValidDate")
				return nil, epss.ErrInvalidDateParam
			}

			qs.Set(qsDate, input.date)
		}
	}

	if len(input.fields) > 0 {
		// todo: add field name validation
		qs.Set(qsFields, strings.Join(input.fields, ","))
	}

	if input.history {
		qs.Set(qsScope, tsScopeVal)
	}

	if input.order != "" {
		qs.Set(qsOrder, input.order)
	}

	if input.scoreGt > 0 {
		qs.Set(qsEPSSGt, fmt.Sprintf("%v", input.scoreGt))
	}

	if input.scoreLt > 0 {
		qs.Set(qsEPSSLt, fmt.Sprintf("%v", input.scoreLt))
	}

	if input.percentileGt > 0 {
		qs.Set(qsPctGt, fmt.Sprintf("%v", input.percentileGt))
	}

	if input.percentileLt > 0 {
		qs.Set(qsPctLt, fmt.Sprintf("%v", input.percentileLt))
	}

	if input.days > 0 {
		qs.Set(qsDays, fmt.Sprintf("%v", input.days))
	}

	if input.limit > 0 {
		qs.Set(qsLimit, fmt.Sprintf("%v", input.limit))
	}

	if input.offset > 0 {
		qs.Set(qsOffset, fmt.Sprintf("%v", input.offset))
	}

	return qs, nil
}
